#include "bottom.h"

namespace lp {

	namespace {
		// TODO: put this as argument to mask_bottom and unmask_bottom
		std::map<const Functor*,Functor> rem;
	}

	// Forget every stashed literal. Any bottom clause that is still masked can no longer
	// be restored by unmask_bottom afterwards (its 'true' placeholders stay in place).
	void reset_mask_bottom()
	{
		rem.erase(rem.begin(), rem.end());
	}

	// Apply `mask` to the body of `bottom`, one mask entry per body literal.
	// Positions with mask 0 are hidden: the literal's content is stashed in rem (keyed by
	// its address) and replaced by the constant 'true'. Positions with mask 1 are visible:
	// if the literal was hidden earlier, its stashed content is moved back and the stash
	// entry removed. Re-hiding an already hidden literal keeps the first stashed content,
	// because std::map::insert does not overwrite an existing key.
	void mask_bottom(Functor& bottom, const bitstring& mask)
	{
		assert(bottom.body_size() == mask.size() || !(std::cerr << bottom.body_size() << " != " << mask.size() << "\n"));
		int pos = 0;
		auto lit_it = bottom.body_begin();
		while (lit_it != bottom.body_end()) {
			Functor& literal = *lit_it;
			if (mask[pos] == 0) {
				// Hide: stash the original, leave a 'true' placeholder
				const Functor* key = &literal;
				rem.insert(std::make_pair(key, std::move(literal)));
				literal = Functor(true_id);
			} else {
				// Show: restore the original if this literal is currently hidden
				assert(mask[pos] == 1);
				auto stored = rem.find(&literal);
				if (stored != rem.end()) {
					literal = std::move(stored->second);
					rem.erase(stored);
				}
			}
			++lit_it;
			++pos;
		}
	}

	// Move every stashed literal of `bottom` back into place, then empty rem completely.
	// Entries whose address matches no body literal of `bottom` are discarded as well.
	void unmask_bottom(Functor& bottom)
	{
		for (auto lit_it = bottom.body_begin(); lit_it != bottom.body_end(); ++lit_it) {
			Functor& literal = *lit_it;
			auto stored = rem.find(&literal);
			if (stored == rem.end()) continue; // not hidden
			literal = std::move(stored->second);
			rem.erase(stored);
		}
		rem.clear();
	}

	inline bool check_type_m(const Functor& f, id_type type, Program& kb, Program::tmemo_t& tmemo)
	{
		const int max_size = kb.params.force_int(parameters::memo_types);
		// Only memoize when f is smaller than maxsize nodes
		int s = 0;
		for (auto i = f.begin(); i != f.end() && s < max_size; ++i,++s);
		if (s < max_size) {
			auto fat = tmemo.find(f);
			if (fat != tmemo.end()) {
				auto iat = fat->second.find(type);
				if (iat != fat->second.end()) {
					return iat->second;
				}
			}
			// Else: memoize
			const bool ans = kb.check_type(f,type);
			if (fat == tmemo.end()) {
				std::map<id_type,bool> imap;
				imap.insert(std::make_pair(type,ans));
				tmemo.insert(std::make_pair(f,std::move(imap)));
			} else {
				fat->second.insert(std::make_pair(type,ans));
			}
			return ans;
		} else {
			return kb.check_type(f,type);
		}
	}

	// Decide whether implication `ip` fires for the candidate literal `lit`.
	// This requires two things: ip.cons unifies with lit, and every antecedent in ip.ante
	// unifies with at least one body literal recorded in bmap that has the same signature.
	// Variables of the bottom-clause literals are first bound in `subs` to fresh constants
	// (fmap.new_fun_id), so they behave as ground terms during unification.
	// Antecedents are matched greedily: the first candidate that unifies is kept (its
	// bindings stay in subs) and is never revisited, even if a later antecedent then fails.
	bool constraint_satisfied(
		const Functor& lit,
		const std::multimap<sign_t,const Functor*>& bmap,
		const implication& ip)
	{
		const Functor& conclusion = ip.cons;
		if (conclusion.signature() != lit.signature()) return false;

		// Conditional constraint
		std::map<id_type,id_type> var2fun; // TODO: reuse this mapping
		auto fresh_constant = [&](id_type var){
			auto hit = var2fun.find(var);
			if (hit == var2fun.end()) hit = var2fun.insert(std::make_pair(var,fmap.new_fun_id())).first;
			return hit->second;
		};

		Subst subs;
		// Bind each still-unbound variable of t to its constant (X -> _x).
		// When `bound` is non-null, record which variables were newly bound.
		auto freeze = [&](const Functor& t, std::set<id_type>* bound) {
			for (const Functor& node : t) {
				if (node.is_variable() && !subs.get(node.id())) {
					subs.steal(node.id(),new Functor(fresh_constant(node.id())));
					if (bound) bound->insert(node.id());
				}
			}
		};

		freeze(lit, nullptr);
		if (!unify(&conclusion,&lit,subs)) return false;

		for (const auto& premise : ip.ante) {
			const auto candidates = bmap.equal_range(premise.signature());
			bool matched = false;
			for (auto c = candidates.first; c != candidates.second; ++c) {
				std::set<id_type> bound;
				freeze(*c->second, &bound);
				if (unify(&premise,c->second,subs)) {
					matched = true;
					break;
				}
				// Undo the temporary bindings introduced for this candidate
				for (id_type id : bound) subs.erase(id);
			}
			if (!matched) return false; // some antecedent has no match
		}
		return true;
	}



	// Print clause `c` in infix form. Body literals that are the constant 'true' (for
	// example literals hidden by mask_bottom) are skipped. Variables that occur exactly
	// once in c are printed as "_". " :- " is emitted only when c has a head and at least
	// one non-true body literal.
	void print_candidate(std::ostream& os, const Functor& c)
	{
		std::map<id_type,int> occurrences;
		for (auto& node : c) {
			if (node.is_variable()) ++occurrences[node.id()];
		}
		std::map<id_type,std::string> vm;
		for (auto& entry : occurrences) {
			if (entry.second == 1) vm[entry.first] = "_";
		}
		string vn;

		if (c.head()) c.head()->print_infix(os,vn,vm);

		bool first_body = true;
		for (auto it = c.body_begin(); it != c.body_end(); ++it) {
			if (it->is_constant(true_id)) continue;
			if (first_body) {
				if (c.head()) os << " :- ";
				first_body = false;
			} else {
				os << ", ";
			}
			it->print_infix(os,vn,vm);
		}
	}

	bool is_subset(const bitstring& p, const bitstring& q)
	{
		assert(p.size() == q.size());
		for (decltype(p.size()) k = 0; k < p.size(); ++k) {
			if (p[k] == 1 && q[k] == 0) return false;
		}
		return true;
	}


	//=============== Termindex members ===============//

	// Debug dump: every key of `mapping` with its mapped elements, then every
	// instantiated term, one per line.
	void Termindex::print(ostream& os) const
	{
		os << "Mapping:\n";
		for (const auto& entry : mapping) {
			os << " " << entry.first << " -> ";
			// Each element of the mapped value is followed by its own newline
			for (const auto& value : entry.second) os << value << "\n";
		}
		os << "Instantiated:\n";
		for (const auto& term : instantiated) os << " " << term << "\n";
	}

	// Run `build` as a query and add every acceptable answer instance to the bottom clause.
	// `build` is a prototype of `mode` whose +type positions the caller has already filled in.
	// An answer instance is rejected unless it is ground and every -type / #type position
	// has the declared type (check_type_m).
	// Without splitvars, an accepted instance is then lifted (updating termindex). It is
	// dropped if it duplicates an existing literal (insts), or if a constraint in `block`
	// fires (checked only when unconditional_constraints is set). Otherwise it is appended
	// to botv, its mode to modemap, and pointers to it are added to insts and bmap.
	// With splitvars, variable_split handles the instance instead.
	// Throws time_out once `deadline` has passed, and bottom_max_length when the number of
	// body literals (botv minus the head) reaches the bsize parameter.
	void Program::instantiate_output(
		Functor& build,
		const Mode& mode,
		Termindex& termindex,
		std::list<Functor>& botv,
		std::vector<Mode>& modemap,
		std::unordered_set<Functor,functor_hash_variant,functor_req>& qterms,
		std::set<const Functor*,functor_ptr_less>& insts,
		std::multimap<sign_t,const Functor*>& bmap,
		tmemo_t& typememo,
		deadline_t deadline)
	{
		//cerr << "instantiate_output, build so far: " << build << "\n";
		// Perform query to retrieve all outputs

		// Use memoization to speed up queries (this assumes previous queries have no side effects)
		// If a query has been made before, it will give precisely the same answer set again

		if (params.is_set(parameters::memo_queries)) {
			if (qterms.find(build) != qterms.end()) {
				// build has been memoized from previous query, skip it
				DEBUG_TRACE( cerr << "Memoization, skipping already queried: " << build << "\n" );
				return;
			} else {
				// Memoize Query (Note: we memoize failed queries too)
				// Recorded before the query runs, so a query that throws also counts as made
				qterms.insert(build);
			}
		}

		// Make query to get all output/ground variables
		answer_type answers;
		DEBUG_TRACE(cerr << "instantiate_outputs, query: " << build << "\n");
		qconstraints qc(params);
		// NOTE(review): presumably limits the number of answers to the mode's recall
		qc.recall = mode.recall();
		int ans; // only read by the trace below
		// Optimize: don't make copy of build
		// q borrows build (owned by the caller) as its child. It is detached on both the
		// normal and the exceptional path. NOTE(review): presumably this keeps q's
		// destructor from releasing build.
		Functor q(if_id);
		q.attach_child(&build);
		try {
			ans = query(q,qc,answers);
		} catch (...) {
			q.unattach_child();
			throw;
		}
		q.unattach_child();

		DEBUG_TRACE(cerr << "instantiate_outputs, query returned: " << ans << "\n");

		for (auto sub = answers.begin(); sub != answers.end(); ++sub) {
			if (std::chrono::steady_clock::now() > deadline) throw time_out();
			DEBUG_TRACE(cerr << "sub: " << *sub << "\n");
			// Apply this answer's bindings to a fresh copy of the query literal
			Functor instance = build;
			// DEBUG_TRACE(cerr << "unlifted instance:    " << instance << "\n");
			instance.subst(sub->get_expanded());
			DEBUG_TRACE(cerr << "  substituted instance: " << instance << "\n");
			if (!instance.is_ground()) {
				DEBUG_TRACE(cerr << "  non-ground, skipping instance\n");
				continue; // skip instance of not ground
			}

			// Check that each -type is instantiated, and to the right type
			auto fail_at = find_if(mode.begin(), mode.end(), [&](const Mode::pmarker& pm) {
				return pm.iotype == Mode::output && !check_type_m(*instance.get(pm.pos), pm.type, *this, typememo);
			});
			if (fail_at != mode.end()) {
				DEBUG_TRACE(cerr << "  outputs have wrong type\n");
				continue; // -types wrong
			}
			DEBUG_TRACE(cerr << "  outputs are ok\n");

			// Check that each #type is ground and of right type
			fail_at = find_if(mode.begin(), mode.end(), [&](const Mode::pmarker& pm) -> bool {
				if (pm.iotype == Mode::ground) {
					const Functor* f = instance.get(pm.pos);
					DEBUG_TRACE(cerr << "  should be ground: " << *f << "\n");
					return !f->is_ground() || !check_type_m(*f,pm.type,*this,typememo);
				} else return false;
			});
			if (fail_at != mode.end()) {
				DEBUG_TRACE(cerr << "  ground types are not ground\n");
				continue; // #types wrong
			}
			DEBUG_TRACE(cerr << "  ground types are ok\n");

			// NOTE(review): stale debug block below - it refers to linstance before its declaration
			//DEBUG_TRACE(
			//if (find(botv.begin(),botv.end(),linstance) != botv.end()) {
			//	cerr << "ERROR: trying to add dupe instance: " << linstance << "\n";
			//	cerr << "botv: ";
			//	for_each(botv.begin(),botv.end(),[&](const Functor& f){ cerr << f << " "; });
			//	cerr << "\n";
			//	cerr << "insts: ";
			//	for_each(insts.begin(),insts.end(),[&](const Functor* f){ cerr << *f << " "; });
			//	cerr << "\n";
			//	exit(1);
			//});

			// ======= Passed all Tests, Add/Update now ======= //


			// linstance is to be a lifted instance
			Functor linstance = instance;

			if (!params.is_set(parameters::splitvars)) {
				// No variable splitting
				DEBUG_TRACE(cerr << "  checking if duplicate...\n");
				// Lift and update termindex
				// (this happens before the duplicate/constraint checks, so rejected
				// instances still leave their index entries behind)
				lift(linstance, mode.begin(), mode.end(), termindex);
				// Ignore if duplicate (note: term is still indexed, but that doesn't matter as it's unique)
				if (insts.find(&linstance) != insts.end()) {
					DEBUG_TRACE(cerr << "  duplicate, ignoring\n");
					continue; 
				}
				// Check constraints
				if (params.is_set(parameters::unconditional_constraints)) {
					bool ignore = false;
					for (const auto& c : block) {
						if (constraint_satisfied(linstance,bmap,c)) {
							//std::cerr << "IGNORING " << linstance << " due to unconditional constraint "; 
							//if (!c.first.empty()) std::cerr << c.first.front();
							//std::cerr << " prevents " << c.second << "\n";
							//for (const auto& f : c.first) std::cerr << f << ",";
							//std::cerr << " prevents " << c.second << "\n";
							ignore = true;
							break;
						}
					}
					if (ignore) continue;
				}

				// New, update instantiated
				termindex.instantiate_outputs(instance,linstance,mode);
				//termindex.update_body(instance, mode.begin(), mode.end());
				// Add instance
				DEBUG_TRACE(cerr << "  ++ Adding literal instance: " << linstance << " ++\n");
				botv.push_back(std::move(linstance));
				// Add mode (modemap stays index-aligned with botv)
				modemap.push_back(mode);
				// Memoize
				// insts and bmap hold pointers into botv (stable because botv is a std::list)
				insts.insert(&botv.back());
				bmap.insert(std::make_pair(botv.back().signature(),&botv.back()));
				// Check bsize. Note: botv.size() also contains head, so botv.size-1 is no of body literals
				if (int(botv.size())-1 >= params.force_int(parameters::bsize)) throw bottom_max_length();
			} else {
				// Variable splitting
				// Note: we cannot get a dupe when variable splitting
				variable_split(instance,mode.begin(),mode,termindex,botv,modemap,insts,params);
			}
		}
	}


	// Recursively fill the +type positions of `build`, starting at pm_current, with every
	// type-compatible instantiated term from termindex. pm_ignore (already filled in by the
	// caller) and non-input positions are skipped. Once all positions have been handled,
	// instantiate_output runs the query for the output positions.
	// Throws time_out once `deadline` has passed.
	void Program::instantiate_input(
		Functor& build,
		const Mode& mode,
		pmiter pm_current,
		pmiter pm_ignore,
		Termindex& termindex,
		list<Functor>& botv,
		vector<Mode>& modemap,
		std::unordered_set<Functor,functor_hash_variant,functor_req>& qterms,
		set<const Functor*,functor_ptr_less>& insts,
		std::multimap<sign_t,const Functor*>& bmap,
		tmemo_t& typememo,
		deadline_t deadline)
	{
		if (std::chrono::steady_clock::now() > deadline) throw time_out();

		if (pm_current == mode.end()) {
			// Every position handled: query for the outputs
			instantiate_output(build, mode, termindex, botv, modemap, qterms, insts, bmap, typememo, deadline);
			return;
		}

		const pmiter pm_next = std::next(pm_current);
		const bool needs_input = pm_current->iotype == Mode::input && pm_current != pm_ignore;
		if (!needs_input) {
			// Ground/output position, or the one the caller already filled: move on
			instantiate_input(build, mode, pm_next, pm_ignore, termindex, botv, modemap, qterms, insts, bmap, typememo, deadline);
			return;
		}

		// Try every instantiated term at this input position.
		// The instantiated term list is expected not to be resized while it is iterated.
		for (auto term = termindex.inst_begin(); term != termindex.inst_end(); ++term) {
			DEBUG_TRACE(cerr << "checking type for: " << *term << " against ");
			DEBUG_TRACE(cerr << Functor::get_data(pm_current->type) << "\n");
			if (!check_type_m(*term, pm_current->type, *this, typememo)) continue;
			// build is not restored afterwards; later assignments overwrite this position
			DEBUG_TRACE(cerr << "instantiating_input, " << *term << ": " << build << " -> ");
			*build.get(pm_current->pos) = *term;
			DEBUG_TRACE(cerr << build << "\n");
			instantiate_input(build, mode, pm_next, pm_ignore, termindex, botv, modemap, qterms, insts, bmap, typememo, deadline);
		}
	}


	// For every body mode, place the term `trm` in each +type position whose type it
	// satisfies, then let instantiate_input fill the remaining inputs and query for outputs.
	// Modes are skipped when they are head modes, when their predicate has neither a builtin
	// nor a kb definition, or when determination() rules them out for the target predicate
	// (the signature of the head, botv.front()).
	// Throws time_out once `deadline` has passed.
	void Program::instantiate_one(
		const Functor& trm,
		Termindex& termindex,
		list<Functor>& botv,
		vector<Mode>& modemap,
		std::unordered_set<Functor,functor_hash_variant,functor_req>& qterms,
		set<const Functor*,functor_ptr_less>& insts,
		std::multimap<sign_t,const Functor*>& bmap,
		tmemo_t& typememo,
		deadline_t deadline)
	{
		assert(!botv.empty());
		const sign_t target_sign = botv.front().signature();

		for (auto mode_it = modev.begin(); mode_it != modev.end(); ++mode_it) {
			if (std::chrono::steady_clock::now() > deadline) throw time_out();
			if (mode_it->is_head()) continue;

			const sign_t sign = mode_it->atom().signature();
			const bool defined = builtin.find(sign) != builtin.end() || kb.find(sign) != kb.end();
			if (!defined) {
				DEBUG_WARNING(std::cerr << "Ignoring " << *mode_it << " as it has no predicate definition\n");
				continue;
			}
			if (!determination(target_sign,mode_it->atom().signature())) continue;

			for (auto pm = mode_it->begin(); pm != mode_it->end(); ++pm) {
				if (pm->iotype != Mode::input) continue;
				if (!check_type_m(trm, pm->type,*this,typememo)) continue;

				Functor build = mode_it->prototype();
				DEBUG_TRACE(cerr << "instantiate_one, build prototype: " << build << "\n");
				DEBUG_TRACE(cerr << "input variable is: " << *build.get(pm->pos) << "\n");
				DEBUG_TRACE(cerr << "term is: " << trm << "\n");
				*build.get(pm->pos) = trm; // p(X,Y,Z) -> p(a,Y,Z)
				DEBUG_TRACE( cerr << "instantiate_one, build: " << build << "\n" );
				// Fill the other +types; pm is passed as the position to leave untouched
				instantiate_input(build, *mode_it, mode_it->begin(), pm, termindex, botv, modemap, qterms, insts, bmap, typememo, deadline);
			}
		}
	}


	bool Program::determination(const sign_t& target, const sign_t& s)
	{
		// Check Exclude
		auto nr = excludes.equal_range(target);
		if (find_if(nr.first,nr.second,[&](const det_type::value_type& p){return p.second == s;}) != nr.second) {
			return false;
		}
		// If we have no include list, accept everything
		if (includes.empty()) return true;
		// Check include
		auto r = includes.equal_range(target);
		return find_if(r.first,r.second,[&](const det_type::value_type& p){return p.second == s;})
			!= r.second;
	}

	// Build the bottom clause for example `e` (which must have a head):
	//  1. find a modeh whose prototype matches head(e) and whose argument types check,
	//  2. lift that head (or variable-split it, with splitvars) into botv,
	//  3. add body literals: first from modes with no input positions (iteration 0),
	//     then `biterations` rounds of instantiate_one over the instantiated terms.
	//     A negative biterations adds no body literals,
	//  4. copy botv into a clause (head + body, in botv order).
	// The modes used for the literals are appended to `modemap`.
	// Returns a default-constructed (NIL) Functor when no modeh fits e.
	// Throws time_out once `deadline` has passed. Reaching the bsize limit is not an
	// error: body construction just stops and the clause built so far is returned.
	Functor Program::bottom(const Functor& e, std::vector<Mode>& modemap, deadline_t deadline)
	{
		assert(e.head());

		// Map terms to variable indices and instantiated terms
		Termindex termindex;
		// Head and body literals of the bottom clause under construction.
		// Must be a std::list: insts and bmap store pointers into it, and a
		// reallocating std::vector would invalidate them.
		std::list<Functor> botv;
		// Pointers to botv elements (faster redundancy checking than scanning botv)
		std::set<const Functor*,functor_ptr_less> insts;
		// Body literals indexed by signature (used by constraint_satisfied)
		std::multimap<sign_t,const Functor*> bmap;
		// Memoized type checks (see check_type_m)
		tmemo_t typememo;

		// Queries that have already been made (for memoization)
		std::unordered_set<Functor,functor_hash_variant,functor_req> qterms;

		// Find a modeh that corresponds with head(e)
		Functor head;
		auto m = modev.begin();
		for (; m != modev.end(); ++m) {
			if (std::chrono::steady_clock::now() > deadline) throw time_out();
			if (!m->is_head()) continue; // only look for modeh
			if (m->atom().id() == e.head()->id() && m->atom().arity() == e.head()->arity()) {
				head = m->prototype();
				Subst subs;
				DEBUG_TRACE(cerr << "Bottom head: " << head << " -> " << *e.head() << "\n");
				if (match(&head, e.head(), subs)) {
					head.subst(subs.get_expanded());
					DEBUG_TRACE(cerr << " success: " << head << " with " << subs << "\n");
					// head will be ground since example was
					// Note: when this is relaxed, check that #types are ground
					assert(head.is_ground());
					// Check head type
					auto i = m->begin();
					for (; i != m->end(); ++i) {
						if (!check_type_m(*head.get(i->pos), i->type, *this, typememo)) break;
					}
					if (i == m->end()) {
						// Found a corresponding modeh
						break;
					}
					// else: try to find another modeh
				}
			}
		}
		// Fail if we didn't find any head modes (return NIL functor)
		if (m == modev.end()) return Functor();

		assert(head.is_fact());

		// Signature of the target predicate (m is the selected modeh from here on)
		const sign_t target_sign = m->atom().signature();

		if (!params.is_set(parameters::splitvars)) {
			// No variable splitting: lift head (indexes its terms)
			Functor lhead = head;
			lift(lhead, m->begin(), m->end(), termindex);
			// Update instantiated
			termindex.instantiate_inputs(head,lhead,*m);
			// Store corresponding mode
			modemap.push_back(*m);
			// Head goes first in botv
			botv.push_front(lhead);
			// Memoize it (log(n) lookup instead of traversing botv)
			insts.insert(&botv.front());
		} else {
			// Variable split (instantiates terms, adds equalities and lifted head)
			variable_split_head(head,m->begin(),*m,termindex,botv, modemap, insts);
		}

		DEBUG_TRACE(cerr << "Bottom head: " << botv.front() << "\n");

		/* TODO: have a "instantiate all ground terms" option, which instantiates all (intput AND output) variables
		to ALL possible ground terms => no need for mode declarations */
		// Create body literals iteration-wise
		try {
			const auto max_iterations = params.force_int(parameters::biterations);
			if (max_iterations >= 0) {
				// Iteration 0: create output-only literals
				for (auto mi = modev.begin(); mi != modev.end(); ++mi) {
					if (std::chrono::steady_clock::now() > deadline) throw time_out();
					if (mi->is_head()) continue; // skip modeh
					// Check determination
					if (!determination(target_sign,make_pair(mi->atom().id(),mi->atom().arity()))) continue;
					if (mi->input_begin() == mi->input_end()) {
						// Mode mi has no input types
						Functor build = mi->prototype();
						instantiate_output(build,*mi,termindex,botv,modemap,qterms,insts,bmap,typememo,deadline);
					}
				}
			}
			for (int iteration = 1; iteration <= max_iterations; ++iteration) {
				if (std::chrono::steady_clock::now() > deadline) throw time_out();
				termindex.update_iteration(); // move previous instantiations into current
				// Create literals that take each term instantiated so far as an input
				for (auto i = termindex.inst_begin(); i != termindex.inst_end(); ++i) {
					instantiate_one(*i,termindex,botv,modemap,qterms,insts,bmap,typememo,deadline);
				}
			}
		} catch (const bottom_max_length&) {
			// bsize reached: keep what has been built so far
			DEBUG_WARNING(cerr << "warning: bottom max length " << params.force_int(parameters::bsize) << " reached\n");
		}

		// Create bottom clause from botv: first element is the head, the rest is the body
		Functor bot;
		auto l = botv.begin();
		assert(l != botv.end());
		bot.head(l->copy());
		Functor* last = nullptr;
		DEBUG_TRACE(cerr << "Making Bottom from " << (botv.size()-1) << " literals\n");
		for (++l; l != botv.end(); ++l) {
			Functor* el = l->copy();
			bot.body_push_back( el, last );
			last = el;
		}

		// Note: no need to deallocate insts/bmap, they only point into botv
		return bot;
	}
}

